import json


# Evaluate a results file against the problem bank: keep the test-split
# problems of the selected grades, score every run stored in the results
# file, and print the accuracy averaged over the runs.

# Locations of the problem bank and of the results file to evaluate.
PROBLEMS_PATH = "/home/test/yxl/MCoT/data/scienceqa/problems.json"
RESULTS_PATH = "/home/test/yxl/MCoT/sqa/results/mistral-small3.1_test/exp1_test_QCM-ALE_seed_3_4random_1.json"

# Grade levels kept by the filter. To evaluate the lower grades instead, use:
# GRADES = frozenset({"grade1", "grade2", "grade3", "grade4", "grade5", "grade6"})
GRADES = frozenset(
    {"grade7", "grade8", "grade9", "grade10", "grade11", "grade12"}
)


def load_json(path):
    """Read and parse the JSON file at *path*, decoding it as UTF-8."""
    # Explicit encoding: the platform default is locale-dependent.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def select_test_questions(questions, grades=GRADES):
    """Return the problems in the "test" split whose grade is in *grades*.

    Args:
        questions: Mapping of question id to a problem dict; the "grade"
            and "split" keys are read with ``dict.get``.
        grades: Collection of accepted grade labels.

    Returns:
        A dict holding only the matching ``qid: details`` pairs.
    """
    # An alternative filter keeps a single subject instead of a grade range:
    # details.get("subject") == "language science"
    return {
        qid: details
        for qid, details in questions.items()
        if details.get("grade") in grades and details.get("split") == "test"
    }


def normalize_answer(value):
    """Return *value* as an int when it is a string holding an integer.

    Any other value is returned unchanged, so an answer index stored as
    ``"2"`` compares equal to the integer ``2``.
    """
    if isinstance(value, str):
        try:
            return int(value)
        except ValueError:
            return value
    return value


def run_accuracy(selected, predictions):
    """Return the accuracy (in percent) of one run over *selected*.

    Args:
        selected: Mapping of question id to a problem dict with an
            "answer" key (the ground truth).
        predictions: Mapping of question id to the predicted answer.
            Keys are assumed to have the same type as those of *selected*.

    Returns:
        ``correct / total * 100``, or 0 when *selected* is empty.
    """
    total = len(selected)
    if total == 0:
        return 0
    correct = 0
    for qid, details in selected.items():
        # -1 stands for "no prediction found" and is counted as wrong.
        predicted = normalize_answer(predictions.get(qid, -1))
        # Both sides are normalised so that string/int encodings of the
        # same answer index are treated as equal.
        if predicted == normalize_answer(details["answer"]):
            correct += 1
    return correct / total * 100


def mean_accuracy(selected, results_data):
    """Return the accuracy averaged over every run in *results_data*.

    Args:
        selected: Problems to score, as returned by select_test_questions.
        results_data: Parsed results file; ``results_data["all_results"]``
            maps a run key to a dict whose "results" entry holds that
            run's predictions.

    Returns:
        The mean of the per-run accuracies, or 0.0 when there is no run
        (instead of raising ZeroDivisionError).
    """
    accuracies = [
        run_accuracy(selected, run["results"])
        for run in results_data["all_results"].values()
    ]
    if not accuracies:
        return 0.0
    return sum(accuracies) / len(accuracies)


def main():
    """Load both files, score all runs and print the average accuracy."""
    questions = load_json(PROBLEMS_PATH)
    results_data = load_json(RESULTS_PATH)
    selected = select_test_questions(questions)
    average_accuracy = mean_accuracy(selected, results_data)
    print(f"准确率: {average_accuracy:.2f}%")


if __name__ == "__main__":
    main()